;;;; -*- mode: lisp; indent-tabs-mode: nil -*-
;;;; argon2.lisp -- implementation of the Argon2 key derivation function

;;; Based on the Argon2 implementation present in the Monocypher
;;; crypto library (http://loup-vaillant.fr/projects/monocypher/)

(in-package :crypto)


(defclass argon2 ()
  ((block
    :accessor argon2-block
    :type (simple-array (unsigned-byte 64) (128))
    :documentation "128-word scratch block from which the data-independent
index generator reads its pseudo-random words.")
   (pass-number
    :accessor argon2-pass-number
    :documentation "Pass currently being computed, recorded by the GIDX-INIT functions.")
   (slice-number
    :accessor argon2-slice-number
    :documentation "Segment (slice) currently being computed, recorded by the GIDX-INIT functions.")
   (nb-blocks
    :accessor argon2-nb-blocks
    :documentation "Number of blocks actually used by the fill loop.")
   (block-count
    :accessor argon2-block-count
    :documentation "Block count chosen when the instance was initialized.")
   (nb-iterations
    :accessor argon2-nb-iterations
    :documentation "Number of passes over the work area.")
   (counter
    :accessor argon2-counter
    :documentation "Counter bumped before every refresh of the BLOCK slot.")
   (offset
    :accessor argon2-offset
    :documentation "Position inside the current segment; incremented by the GIDX-NEXT functions.")
   (additional-key
    :accessor argon2-additional-key
    :type (simple-array (unsigned-byte 8) (*))
    :documentation "Optional key octets absorbed into the initial hash.")
   (additional-data
    :accessor argon2-additional-data
    :type (simple-array (unsigned-byte 8) (*))
    :documentation "Optional associated-data octets absorbed into the initial hash.")
   (work-area
    :accessor argon2-work-area
    :type (simple-array (unsigned-byte 64) (*))
    :documentation "Flat vector of 64-bit words holding all blocks back to back.")
   (digester
    :accessor argon2-digester
    :documentation "MAC object reused for every hash computed by the KDF."))
  (:documentation "State shared by the Argon2 key derivation function variants."))

(defclass argon2i (argon2)
  ()
  (:documentation "Argon2 variant hashed as type 1; DERIVE-KEY uses the
data-independent index generator for every segment."))

(defclass argon2d (argon2)
  ()
  (:documentation "Argon2 variant hashed as type 0; DERIVE-KEY uses the
data-dependent index generator for every segment."))

(defclass argon2id (argon2)
  ()
  (:documentation "Argon2 variant hashed as type 2; DERIVE-KEY starts with the
data-independent index generator and switches to the data-dependent one from
segment 2 of the first pass onwards."))

(defconstant +argon2-block-size+ 128
  "Number of 64-bit words in one Argon2 block.")

(deftype argon2-block ()
  "A single Argon2 block: a simple vector of exactly 128 64-bit words."
  '(simple-array (unsigned-byte 64) (128)))


(defun argon2-load-block (b bytes)
  "Fill the first +ARGON2-BLOCK-SIZE+ words of B with the little-endian
64-bit integers stored in consecutive 8-octet groups of BYTES, starting at
octet 0.  Returns no values."
  (declare (type (simple-array (unsigned-byte 64) (*)) b)
           (type (simple-array (unsigned-byte 8) (*)) bytes))
  (loop for word-index from 0 below +argon2-block-size+
        for octet-offset from 0 by 8
        do (setf (aref b word-index) (ub64ref/le bytes octet-offset)))
  (values))

(defun argon2-store-block (bytes b &key (start2 0))
  "Serialize block number START2 of the word vector B into the first
8 * +ARGON2-BLOCK-SIZE+ octets of BYTES, each word in little-endian order.
Returns no values."
  (declare (type (simple-array (unsigned-byte 64) (*)) b)
           (type (simple-array (unsigned-byte 8) (*)) bytes))
  (loop for word-index from 0 below +argon2-block-size+
        for octet-offset from 0 by 8
        do (setf (ub64ref/le bytes octet-offset)
                 (aref b (+ (* +argon2-block-size+ start2) word-index))))
  (values))

(defun argon2-copy-block (b1 b2 &key (start1 0) (start2 0))
  "Copy block number START2 of B2 onto block number START1 of B1, word by
word in ascending order.  START1 and START2 count blocks, not words.
Returns no values."
  (declare (type (simple-array (unsigned-byte 64) (*)) b1 b2))
  (let ((to-base (* +argon2-block-size+ start1))
        (from-base (* +argon2-block-size+ start2)))
    (loop for k from 0 below +argon2-block-size+
          do (setf (aref b1 (+ to-base k))
                   (aref b2 (+ from-base k)))))
  (values))

(defun argon2-xor-block (b1 b2 &key (start1 0) (start2 0))
  "XOR block number START2 of B2 into block number START1 of B1, in place.
START1 and START2 count blocks, not words.  Returns no values."
  (declare (type (simple-array (unsigned-byte 64) (*)) b1 b2))
  (let ((to-base (* +argon2-block-size+ start1))
        (from-base (* +argon2-block-size+ start2)))
    (loop for k from 0 below +argon2-block-size+
          for target = (+ to-base k)
          do (setf (aref b1 target)
                   (logxor (aref b1 target)
                           (aref b2 (+ from-base k))))))
  (values))

(defun argon2-update-digester-32 (digester input)
  "Feed the integer INPUT to DIGESTER as a 32-bit little-endian value.
Returns no values."
  (let ((encoded (integer-to-octets input :n-bits 32 :big-endian nil)))
    (update-mac digester encoded))
  (values))

(defun argon2-extended-hash (state digest digest-size input input-size)
  "Write a DIGEST-SIZE octet hash of the first INPUT-SIZE octets of INPUT
into DIGEST, using the MAC object stored in STATE with an empty key.

Up to 64 octets a single MAC computation is enough.  For longer outputs the
first 64-octet result is chained: each step hashes the 64 octets starting at
the read position and writes its result 32 octets further on, so that
consecutive results overlap by half.  The last step uses whatever digest
length is still missing.  Returns no values."
  (declare (type (simple-array (unsigned-byte 8) (*)) digest input))
  (let ((empty-key (make-array 0 :element-type '(unsigned-byte 8)))
        (mac (argon2-digester state)))
    (flet ((restart-mac (output-length)
             (reinitialize-instance mac :key empty-key :digest-length output-length)))
      ;; First (possibly only) hash: the requested size followed by the input.
      (restart-mac (min digest-size 64))
      (argon2-update-digester-32 mac digest-size)
      (update-mac mac input :end input-size)
      (produce-mac mac :digest digest)
      (when (> digest-size 64)
        (do ((rounds (- (ceiling digest-size 32) 2))
             (i 1 (1+ i))
             (src 0 (+ src 32))
             (dst 32 (+ dst 32)))
            ((>= i rounds)
             ;; Final step: shortened digest covering the remaining octets.
             (restart-mac (- digest-size (* 32 rounds)))
             (update-mac mac digest :start src :end (+ src 64))
             (produce-mac mac :digest digest :digest-start dst))
          ;; Intermediate step: source and destination overlap on purpose;
          ;; the source is absorbed before the result is written.
          (restart-mac 64)
          (update-mac mac digest :start src :end (+ src 64))
          (produce-mac mac :digest digest :digest-start dst)))))
  (values))

(defmacro argon2-g (a b c d)
  "Expand into the quarter-round that updates the four places A, B, C and D.

Each mixing step computes x + y + 2 * low32(x) * low32(y) modulo 2^64, and
each rotation step rotates the XOR of two places right by 32, 24, 16 and
then 63 bits.  A, B, C and D must be places; they are read and written
several times by the expansion."
  (flet ((mix (x y)
           ;; x + y + 2 * low32(x) * low32(y), all modulo 2^64
           `(mod64+ ,x (mod64+ ,y (mod64* 2 (mod64* (logand ,x #xffffffff)
                                                    (logand ,y #xffffffff))))))
         (rotate-xor (x y amount)
           `(ror64 (logxor ,x ,y) ,amount)))
    `(setf ,a ,(mix a b)
           ,d ,(rotate-xor d a 32)
           ,c ,(mix c d)
           ,b ,(rotate-xor b c 24)
           ,a ,(mix a b)
           ,d ,(rotate-xor d a 16)
           ,c ,(mix c d)
           ,b ,(rotate-xor b c 63))))

(defmacro argon2-round (v0 v1 v2 v3 v4 v5 v6 v7 v8 v9 v10 v11 v12 v13 v14 v15)
  "Expand into eight ARGON2-G applications over the sixteen places V0..V15:
first the four columns of the 4x4 arrangement, then its four diagonals."
  (let ((places (vector v0 v1 v2 v3 v4 v5 v6 v7 v8 v9 v10 v11 v12 v13 v14 v15)))
    `(progn
       ,@(loop for (a b c d) in '((0 4 8 12) (1 5 9 13) (2 6 10 14) (3 7 11 15)
                                  (0 5 10 15) (1 6 11 12) (2 7 8 13) (3 4 9 14))
               collect `(argon2-g ,(aref places a)
                                  ,(aref places b)
                                  ,(aref places c)
                                  ,(aref places d))))))

(defun argon2-g-rounds (work-block)
  "Permute WORK-BLOCK in place with two sweeps of ARGON2-ROUND.

The first sweep handles eight groups of sixteen consecutive words.  The
second sweep handles eight groups made of word pairs taken every sixteen
words.  Returns no values."
  (declare (type argon2-block work-block))
  (macrolet ((round-on (base &rest offsets)
               ;; Expands to ARGON2-ROUND applied to the places
               ;; (aref work-block (+ BASE offset)), in the order given.
               `(argon2-round ,@(loop for offset in offsets
                                      collect `(aref work-block (+ ,base ,offset))))))
    ;; Sweep 1: sixteen consecutive words per round.
    (loop for i from 0 below 128 by 16
          do (round-on i 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15))
    ;; Sweep 2: two adjacent words out of every sixteen per round.
    (loop for i from 0 below 16 by 2
          do (round-on i 0 1 16 17 32 33 48 49 64 65 80 81 96 97 112 113)))
  (values))

(defun argon2-g-copy (work-area r x y)
  "Overwrite block R of WORK-AREA with the compression of blocks X and Y.

With mixed = X xor Y, block R becomes mixed xor permute(mixed).  R, X and Y
are block numbers inside WORK-AREA.  Returns no values."
  (declare (type (simple-array (unsigned-byte 64) (*)) work-area))
  (let ((scratch (make-array +argon2-block-size+ :element-type '(unsigned-byte 64))))
    (declare (type argon2-block scratch)
             (dynamic-extent scratch))
    ;; scratch = X xor Y
    (argon2-copy-block scratch work-area :start2 x)
    (argon2-xor-block scratch work-area :start2 y)
    ;; R = X xor Y (the previous content of R is discarded)
    (argon2-copy-block work-area scratch :start1 r)
    ;; scratch = permute(X xor Y), then R = R xor scratch
    (argon2-g-rounds scratch)
    (argon2-xor-block work-area scratch :start1 r))
  (values))

(defun argon2-g-xor (work-area r x y)
  "XOR the compression of blocks X and Y into block R of WORK-AREA.

With mixed = X xor Y, block R becomes R xor mixed xor permute(mixed); unlike
ARGON2-G-COPY the previous content of R is kept.  R, X and Y are block
numbers inside WORK-AREA.  Returns no values."
  (declare (type (simple-array (unsigned-byte 64) (*)) work-area))
  (let ((scratch (make-array +argon2-block-size+ :element-type '(unsigned-byte 64))))
    (declare (type argon2-block scratch)
             (dynamic-extent scratch))
    ;; scratch = X xor Y
    (argon2-copy-block scratch work-area :start2 x)
    (argon2-xor-block scratch work-area :start2 y)
    ;; R = R xor (X xor Y)
    (argon2-xor-block work-area scratch :start1 r)
    ;; scratch = permute(X xor Y), then R = R xor scratch
    (argon2-g-rounds scratch)
    (argon2-xor-block work-area scratch :start1 r))
  (values))

(defun argon2-unary-g (work-block)
  "Replace WORK-BLOCK, in place, by permute(WORK-BLOCK) xor WORK-BLOCK.
Returns no values."
  (declare (type argon2-block work-block))
  (let ((saved (make-array +argon2-block-size+ :element-type '(unsigned-byte 64))))
    (declare (type argon2-block saved)
             (dynamic-extent saved))
    ;; Keep the input aside, permute in place, then fold the input back in.
    (argon2-copy-block saved work-block)
    (argon2-g-rounds work-block)
    (argon2-xor-block work-block saved))
  (values))

(defun argon2i-gidx-refresh (state)
  "Regenerate the pseudo-random words kept in the BLOCK slot of STATE.

The block is seeded with the pass number, a zero word, the slice number, the
block count, the iteration count, the variant code (1 for ARGON2I, 2 for
ARGON2ID) and the counter, padded with zeros, and then run through
ARGON2-UNARY-G twice.  Signals an error for any other class of STATE.
Returns no values."
  (let ((addresses (argon2-block state)))
    (setf (aref addresses 0) (argon2-pass-number state))
    (setf (aref addresses 1) 0)
    (setf (aref addresses 2) (argon2-slice-number state))
    (setf (aref addresses 3) (argon2-nb-blocks state))
    (setf (aref addresses 4) (argon2-nb-iterations state))
    (setf (aref addresses 5) (etypecase state
                               (argon2i 1)
                               (argon2id 2)))
    (setf (aref addresses 6) (argon2-counter state))
    ;; Everything after the seven seed words is cleared.
    (fill addresses 0 :start 7)
    (loop repeat 2
          do (argon2-unary-g addresses)))
  (values))

(defun argon2i-gidx-init (state pass-number slice-number nb-blocks nb-iterations)
  "Prepare STATE for data-independent indexing of one segment.

Records the arguments in STATE and resets the counter.  For the very first
segment (pass 0, slice 0) the offset starts at 2 and the address block is
generated right away, because ARGON2I-GIDX-NEXT only refreshes it when the
offset is a multiple of the block size.  For every other segment the offset
starts at 0.  Returns no values."
  (setf (argon2-pass-number state) pass-number)
  (setf (argon2-slice-number state) slice-number)
  (setf (argon2-nb-blocks state) nb-blocks)
  (setf (argon2-nb-iterations state) nb-iterations)
  (setf (argon2-counter state) 0)
  (cond ((and (zerop pass-number) (zerop slice-number))
         (setf (argon2-offset state) 2)
         (incf (argon2-counter state))
         (argon2i-gidx-refresh state))
        (t
         (setf (argon2-offset state) 0)))
  (values))

(defun argon2i-gidx-next (state)
  "Return the number of the next reference block, chosen from the
pseudo-random words of the BLOCK slot, and advance the offset of STATE.

The address block is regenerated (after bumping the counter) whenever the
offset is a multiple of the block size."
  (let* ((segment-offset (argon2-offset state))
         (word-index (mod segment-offset +argon2-block-size+)))
    (when (zerop word-index)
      (incf (argon2-counter state))
      (argon2i-gidx-refresh state))
    (let* ((initial-pass-p (zerop (argon2-pass-number state)))
           (total-blocks (argon2-nb-blocks state))
           (quarter (floor total-blocks 4))
           (slice (argon2-slice-number state))
           ;; Number of whole segments that may be referenced.
           (visible-segments (if initial-pass-p slice 3))
           (window (+ (* visible-segments quarter) segment-offset -1))
           ;; The window starts at block 0 during the first pass and at the
           ;; segment following the current one afterwards.
           (origin (if initial-pass-p
                       0
                       (* (mod (1+ slice) 4) quarter)))
           (j1 (ldb (byte 32 0) (aref (argon2-block state) word-index)))
           (x (ash (* j1 j1) -32))
           (y (ash (* window x) -32))
           (z (- window 1 y)))
      (incf (argon2-offset state))
      (mod (+ origin z) total-blocks))))

(defun argon2d-gidx-init (state pass-number slice-number nb-blocks nb-iterations)
  "Prepare STATE for data-dependent indexing of one segment.

Records the arguments in STATE, resets the counter, and sets the offset to 2
for the very first segment (pass 0, slice 0) or to 0 otherwise.  Returns no
values."
  (setf (argon2-pass-number state) pass-number)
  (setf (argon2-slice-number state) slice-number)
  (setf (argon2-nb-blocks state) nb-blocks)
  (setf (argon2-nb-iterations state) nb-iterations)
  (setf (argon2-counter state) 0)
  (setf (argon2-offset state)
        (if (and (zerop pass-number) (zerop slice-number))
            2
            0))
  (values))

(defun argon2d-gidx-next (state previous-block)
  "Return the number of the next reference block, chosen from the first word
of block PREVIOUS-BLOCK in the work area, and advance the offset of STATE."
  (let* ((segment-offset (argon2-offset state))
         (first-word (* +argon2-block-size+ previous-block))
         (initial-pass-p (zerop (argon2-pass-number state)))
         (total-blocks (argon2-nb-blocks state))
         (quarter (floor total-blocks 4))
         (slice (argon2-slice-number state))
         ;; Number of whole segments that may be referenced.
         (visible-segments (if initial-pass-p slice 3))
         (window (+ (* visible-segments quarter) segment-offset -1))
         ;; The window starts at block 0 during the first pass and at the
         ;; segment following the current one afterwards.
         (origin (if initial-pass-p
                     0
                     (* (mod (1+ slice) 4) quarter)))
         (j1 (ldb (byte 32 0) (aref (argon2-work-area state) first-word)))
         (x (ash (* j1 j1) -32))
         (y (ash (* window x) -32))
         (z (- window 1 y)))
    (incf (argon2-offset state))
    (mod (+ origin z) total-blocks)))

(defmethod shared-initialize ((kdf argon2) slot-names &rest initargs
                              &key block-count additional-key additional-data &allow-other-keys)
  "Set up every slot an ARGON2 instance needs before DERIVE-KEY can run.

BLOCK-COUNT is the requested number of blocks; it defaults to 4096 and is
raised to 8 when smaller.  ADDITIONAL-KEY and ADDITIONAL-DATA default to
empty octet vectors.  Returns KDF.

The work area is allocated from the effective (defaulted and clamped) block
count, i.e. the same value that is stored in the BLOCK-COUNT slot, so that it
always holds as many blocks as DERIVE-KEY is going to address."
  (declare (ignore initargs)
           (ignorable slot-names))
  (let ((no-data (make-array 0 :element-type '(unsigned-byte 8)))
        (effective-block-count (max 8 (or block-count 4096))))
    (setf (argon2-block kdf) (make-array +argon2-block-size+
                                         :element-type '(unsigned-byte 64))
          (argon2-block-count kdf) effective-block-count
          (argon2-additional-key kdf) (or additional-key no-data)
          (argon2-additional-data kdf) (or additional-data no-data)
          ;; Sizing this from the raw BLOCK-COUNT argument would fail when the
          ;; argument is omitted and under-allocate when it is below 8.
          (argon2-work-area kdf) (make-array (* +argon2-block-size+ effective-block-count)
                                             :element-type '(unsigned-byte 64))
          (argon2-digester kdf) (make-mac :blake2-mac no-data)))
  kdf)

;;; Derive KEY-LENGTH octets from PASSPHRASE and SALT with the Argon2
;;; variant selected by the class of KDF, making ITERATION-COUNT passes over
;;; the work area of KDF.  Returns a fresh octet vector.
;;;
;;; Signals UNSUPPORTED-ARGON2-PARAMETERS when KEY-LENGTH is below 4,
;;; ITERATION-COUNT is below 1 or SALT is shorter than 8 octets.
;;;
;;; The work area, address block and MAC object stored in KDF are
;;; overwritten by every call.
(defmethod derive-key ((kdf argon2) passphrase salt iteration-count key-length)
  (declare (type (simple-array (unsigned-byte 8) (*)) passphrase salt))
  (when (or (< key-length 4) (< iteration-count 1) (< (length salt) 8))
    (error 'unsupported-argon2-parameters))
  (setf (argon2-nb-iterations kdf) iteration-count)
  (let ((data-independent-p (or (typep kdf 'argon2i) (typep kdf 'argon2id)))
        (work-area (argon2-work-area kdf))
        (block-count (argon2-block-count kdf))
        (additional-key (argon2-additional-key kdf))
        (additional-data (argon2-additional-data kdf))
        (digester (argon2-digester kdf))
        (no-key (make-array 0 :element-type '(unsigned-byte 8)))
        (tmp-area (make-array 1024 :element-type '(unsigned-byte 8))))
    (declare (type (simple-array (unsigned-byte 64) (*)) work-area)
             (type (simple-array (unsigned-byte 8) (1024)) tmp-area)
             (dynamic-extent tmp-area))
    ;; Initial 64-octet hash over all the parameters and inputs.  Every
    ;; integer is absorbed as a 32-bit little-endian word, and every
    ;; variable-length input is preceded by its length.
    ;; Per RFC 9106 the leading constant 1 is the degree of parallelism and
    ;; #x13 is the version number.
    (reinitialize-instance digester :key no-key :digest-length 64)
    (argon2-update-digester-32 digester 1)
    (argon2-update-digester-32 digester key-length)
    (argon2-update-digester-32 digester block-count)
    (argon2-update-digester-32 digester iteration-count)
    (argon2-update-digester-32 digester #x13)
    ;; Variant code.
    (argon2-update-digester-32 digester (etypecase kdf
                                          (argon2d 0)
                                          (argon2i 1)
                                          (argon2id 2)))
    (argon2-update-digester-32 digester (length passphrase))
    (update-mac digester passphrase)
    (argon2-update-digester-32 digester (length salt))
    (update-mac digester salt)
    (argon2-update-digester-32 digester (length additional-key))
    (update-mac digester additional-key)
    (argon2-update-digester-32 digester (length additional-data))
    (update-mac digester additional-data)
    ;; INITIAL-HASH has room for the 64-octet hash plus two 32-bit words that
    ;; are varied to derive the first two blocks.
    (let ((initial-hash (make-array 72 :element-type '(unsigned-byte 8)))
          (tmp-block (make-array +argon2-block-size+ :element-type '(unsigned-byte 64))))
      (declare (type (simple-array (unsigned-byte 8) (72)) initial-hash)
               (type argon2-block tmp-block)
               (dynamic-extent initial-hash tmp-block))
      (produce-mac digester :digest initial-hash)

      ;; Block 0: extended hash of INITIAL-HASH with both extra words at 0.
      (setf (ub32ref/le initial-hash 64) 0
            (ub32ref/le initial-hash 68) 0)
      (argon2-extended-hash kdf tmp-area 1024 initial-hash 72)
      (argon2-load-block tmp-block tmp-area)
      (argon2-copy-block work-area tmp-block)

      ;; Block 1: same, with the first extra word set to 1.
      (setf (ub32ref/le initial-hash 64) 1)
      (argon2-extended-hash kdf tmp-area 1024 initial-hash 72)
      (argon2-load-block tmp-block tmp-area)
      (argon2-copy-block work-area tmp-block :start1 1))

    ;; Only a multiple of four blocks is used, split into four equal segments.
    (let* ((nb-blocks (- block-count (mod block-count 4)))
           (segment-size (floor nb-blocks 4)))
      (dotimes (pass-number iteration-count)
        (let ((first-pass (zerop pass-number)))
          (dotimes (segment 4)
            ;; ARGON2ID switches to data-dependent indexing at segment 2 of
            ;; the first pass; the flag is never set back afterwards.
            (when (and (= segment 2) (typep kdf 'argon2id))
              (setf data-independent-p nil))
            (if data-independent-p
                (argon2i-gidx-init kdf pass-number segment nb-blocks iteration-count)
                (argon2d-gidx-init kdf pass-number segment nb-blocks iteration-count))
            ;; Blocks 0 and 1 were filled above, so the very first segment
            ;; starts at block 2.
            (let* ((start-offset (if (and first-pass (zerop segment)) 2 0))
                   (segment-start (+ (* segment segment-size) start-offset))
                   (segment-end (* (+ segment 1) segment-size)))
              (loop for current-block from segment-start below segment-end do
                ;; The predecessor of block 0 is the last block.
                (let* ((previous-block (if (zerop current-block)
                                           (- nb-blocks 1)
                                           (- current-block 1)))
                       (reference-block (if data-independent-p
                                            (argon2i-gidx-next kdf)
                                            (argon2d-gidx-next kdf previous-block))))
                  ;; The first pass overwrites the current block; later
                  ;; passes XOR the new value into it.
                  (if first-pass
                      (argon2-g-copy work-area current-block previous-block reference-block)
                      (argon2-g-xor work-area current-block previous-block reference-block))))))))
      ;; The key is the extended hash of the last block, serialized to octets.
      (let ((hash (make-array key-length :element-type '(unsigned-byte 8))))
        (argon2-store-block tmp-area work-area :start2 (- nb-blocks 1))
        (argon2-extended-hash kdf hash key-length tmp-area 1024)
        hash))))
